#!/usr/bin/env python3
import os
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm
import sys 
sys.path.append("./mindeye2_src")
import utils
import random

# Let CUDA matmuls use TF32 tensor cores, then seed via the project helper
# (presumably seeds Python/NumPy/torch RNGs — see mindeye2_src/utils).
torch.backends.cuda.matmul.allow_tf32 = True
utils.seed_everything(42)

# ---------- Basic settings ----------
device = "cuda:0"  # device used for loading tensors and running every model
subj = 1           # subject index; formatted into the subj0{subj}_* paths
eye = ''           # suffix of the reconstruction directory ('' -> ..._prior_vae)

# ---------- Metric switches (True = compute that metric) ----------
pix_corr = True
ssim = True
alexnet = True
inception3 = True
clip = True
eff = True
swav = True

# ---------- Accumulated results (remain 0.0 for disabled metrics) ----------
pix_corr_v = ssim_v = 0.0
alexnet2_v = alexnet5_v = 0.0
inception3_v = clip_v = eff_v = swav_v = 0.0

# ---------- Main evaluation (everything below runs without autograd) ----------
with torch.no_grad():
    # 1. Load the 1000 ground-truth test images and their reconstructions.
    # NOTE(review): the ground-truth path always uses the "_prior_vae_eye"
    # directory regardless of `eye`; presumably the test images are only
    # stored there — confirm this is intended.
    all_images = torch.load(
        f"./subj0{subj}_prior_vae_eye/test_images_1000.pt",
        map_location=device
    )                      # presumably (1000, C, H, W) — TODO confirm
    all_recons = torch.load(
        f"./subj0{subj}_prior_vae{eye}/all_recons.pt",
        map_location=device
    )                      # presumably (1000, C, H, W) — TODO confirm

    print("all_recons.shape: ",all_recons.shape)

    # Every metric below pairs image i with reconstruction i, so a count
    # mismatch would silently mis-align results or crash inside a metric.
    if len(all_images) != len(all_recons):
        raise ValueError(
            f"got {len(all_images)} ground-truth images but "
            f"{len(all_recons)} reconstructions"
        )

    # Both low-level metrics compare images resized so the shorter side is 425 px.
    preprocess = transforms.Compose([
        transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
    ])

    # ---------- 1. Pixel-wise correlation ----------
    if pix_corr:
        # Flatten each resized image to one vector. `reshape` works whether or
        # not the resized tensor is contiguous; `view` would raise otherwise.
        flat_real = preprocess(all_images).reshape(len(all_images), -1).cpu()
        flat_recon = preprocess(all_recons).reshape(len(all_recons), -1).cpu()
        corrsum = 0.0
        for real_vec, recon_vec in tqdm(zip(flat_real, flat_recon), total=len(all_images), desc="pix-corr"):
            # Pearson correlation between the two flattened images.
            corrsum += np.corrcoef(real_vec, recon_vec)[0, 1]
        pixcorr = corrsum / len(all_images)
        print(f"Pixel Corr: {pixcorr:.4f}")
        pix_corr_v += pixcorr

    # ---------- 2. SSIM ----------
    if ssim:
        from skimage.color import rgb2gray
        from skimage.metrics import structural_similarity as ssim_fn
        # rgb2gray wants channels last, hence the permute before converting.
        real_np = rgb2gray(preprocess(all_images).permute(0, 2, 3, 1).cpu())
        recon_np = rgb2gray(preprocess(all_recons).permute(0, 2, 3, 1).cpu())
        # NOTE(review): the inputs are 2-D grayscale arrays here, so
        # `multichannel=True` (a deprecated alias of `channel_axis=-1` in newer
        # scikit-image) treats the last spatial axis as channels. It is left
        # unchanged because switching it would change the reported scores.
        ssim_scores = [
            ssim_fn(rec, real, multichannel=True, gaussian_weights=True,
                    sigma=1.5, use_sample_covariance=False, data_range=1.0)
            for real, rec in tqdm(zip(real_np, recon_np), total=len(all_images), desc="SSIM")
        ]
        ssim_mean = np.mean(ssim_scores)
        print(f"SSIM: {ssim_mean:.4f}")
        ssim_v += ssim_mean

    # ---------- 3. Two-way identification ----------
    from torchvision.models.feature_extraction import create_feature_extractor

    from torchvision.transforms.functional import to_pil_image

    @torch.no_grad()
    def two_way(recons, images, model, preprocess, feat_layer=None, *, is_clip=None):
        """Return the two-way identification accuracy of reconstructions.

        Every image and reconstruction is embedded with ``model``. For each
        reconstruction j, the Pearson correlation with its own ground truth,
        corr(real_j, recon_j), is compared against corr(real_i, recon_j) for
        every other ground-truth image i. The score is the fraction of those
        comparisons the correct pair wins, averaged over all reconstructions.
        Chance level is 0.5.

        Args:
            recons: reconstructed images, iterated one image at a time.
            images: ground-truth images, same length and order as ``recons``.
            model: callable mapping a stacked batch to features, or to a
                mapping of named outputs when ``feat_layer`` is given.
            preprocess: per-image transform. It receives a PIL image on the
                CLIP path and a tensor otherwise.
            feat_layer: key selecting one output of ``model``; None when
                ``model`` returns a tensor directly.
            is_clip: force the CLIP (PIL-image) preprocessing path on or off.
                None keeps the automatic detection heuristic.

        Returns:
            Mean identification accuracy in [0, 1].

        Raises:
            ValueError: if the two inputs differ in length or hold fewer than
                two images (the score divides by ``len(images) - 1``).
        """
        from torchvision.transforms.functional import to_pil_image

        n = len(images)
        if len(recons) != n:
            raise ValueError(
                f"two_way: got {len(recons)} reconstructions for {n} images")
        if n < 2:
            raise ValueError("two_way: need at least two images for a two-way comparison")

        # Auto-detect CLIP: a plain function / bound method (such as
        # clip_model.encode_image) or a type whose repr mentions "clip".
        # This is a heuristic, hence the explicit `is_clip` override.
        if is_clip is None:
            is_clip = hasattr(model, '__name__') or 'clip' in str(type(model))

        def _make_batch(tensor_imgs):
            """Preprocess each image and stack the results into one batch on `device`."""
            if is_clip:
                # CLIP's preprocess pipeline takes PIL images.
                items = [preprocess(to_pil_image(img.cpu())) for img in tensor_imgs]
            else:
                # torchvision transforms operate on tensors directly.
                items = [preprocess(img) for img in tensor_imgs]
            return torch.stack(items, dim=0).to(device)

        preds = model(_make_batch(recons))
        reals = model(_make_batch(images))
        if feat_layer is not None:
            preds = preds[feat_layer]
            reals = reals[feat_layer]
        preds = preds.float().flatten(1).cpu().numpy()
        reals = reals.float().flatten(1).cpu().numpy()

        # np.corrcoef stacks reals (rows 0..n-1) over preds (rows n..2n-1);
        # keep the cross block, so r[i, j] = corr(real_i, recon_j).
        r = np.corrcoef(reals, preds)[:n, n:]
        congruents = np.diag(r)  # corr(real_j, recon_j)
        # success[i, j]: recon j correlates less with image i than with its
        # own ground truth. The diagonal compares a value with itself and is
        # never True, so each column sums to at most n - 1.
        success = r < congruents
        return np.mean(np.sum(success, axis=0)) / (n - 1)

    # 3-a. AlexNet (early & mid layers)
    if alexnet:
        # Alias the builder so this import does not rebind the `alexnet`
        # on/off switch to a function.
        from torchvision.models import alexnet as alexnet_builder, AlexNet_Weights
        alex_model = create_feature_extractor(
            alexnet_builder(weights=AlexNet_Weights.IMAGENET1K_V1),
            return_nodes=['features.4', 'features.11']
        ).to(device).eval().requires_grad_(False)
        preprocess_alex = transforms.Compose([
            transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # 'features.4' / 'features.11' are reported as AlexNet-2 / AlexNet-5.
        alexnet2 = two_way(all_recons, all_images, alex_model, preprocess_alex, 'features.4')
        alexnet5 = two_way(all_recons, all_images, alex_model, preprocess_alex, 'features.11')
        print(f"AlexNet-2: {alexnet2:.4f}, AlexNet-5: {alexnet5:.4f}")
        alexnet2_v += alexnet2
        alexnet5_v += alexnet5

    # 3-b. Inception-v3 (pooled features)
    if inception3:
        from torchvision.models import inception_v3, Inception_V3_Weights
        inc_model = create_feature_extractor(
            inception_v3(weights=Inception_V3_Weights.DEFAULT),
            return_nodes=['avgpool']
        ).to(device).eval().requires_grad_(False)
        preprocess_inc = transforms.Compose([
            transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        inception = two_way(all_recons, all_images, inc_model, preprocess_inc, 'avgpool')
        print(f"Inception-v3: {inception:.4f}")
        inception3_v += inception

    # 3-c. CLIP image embeddings
    if clip:
        # Alias the package so this import does not rebind the `clip` switch.
        import clip as clip_lib
        clip_model, preprocess_clip = clip_lib.load("ViT-L/14", device=device)
        clip_acc = two_way(all_recons, all_images, clip_model.encode_image, preprocess_clip, None)
        print(f"CLIP: {clip_acc:.4f}")
        clip_v += clip_acc

    def _mean_corr_distance(feat_model, preprocess_fn):
        """Return the mean correlation distance between paired image features.

        The ground-truth and reconstructed images are each preprocessed as one
        batch. The model's 'avgpool' output is flattened per image, and scipy's
        correlation distance (1 - Pearson r) is averaged over the
        (ground truth, reconstruction) pairs. Lower is better.
        """
        import scipy.spatial.distance as spdist
        real_feat = feat_model(preprocess_fn(all_images))['avgpool'].reshape(len(all_images), -1).cpu().numpy()
        recon_feat = feat_model(preprocess_fn(all_recons))['avgpool'].reshape(len(all_recons), -1).cpu().numpy()
        return np.mean([spdist.correlation(real, recon) for real, recon in zip(real_feat, recon_feat)])

    # 3-d. EfficientNet-B1 (correlation distance, lower is better)
    if eff:
        from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
        eff_model = create_feature_extractor(
            efficientnet_b1(weights=EfficientNet_B1_Weights.DEFAULT),
            return_nodes=['avgpool']
        ).to(device).eval().requires_grad_(False)
        preprocess_eff = transforms.Compose([
            transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        eff_dist = _mean_corr_distance(eff_model, preprocess_eff)
        print(f"EfficientNet (corr-dist): {eff_dist:.4f}")
        eff_v += eff_dist

    # 3-e. SwAV ResNet-50 (correlation distance, lower is better)
    if swav:
        swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
        swav_model = create_feature_extractor(swav_model, return_nodes=['avgpool']).to(device).eval().requires_grad_(False)
        preprocess_swav = transforms.Compose([
            transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        swav_dist = _mean_corr_distance(swav_model, preprocess_swav)
        print(f"SwAV (corr-dist): {swav_dist:.4f}")
        swav_v += swav_dist

import json

# ---------- Summary ----------
# Print every accumulated metric (disabled metrics still show their 0.0).
print(f"Pixel Corr : {pix_corr_v:.4f}")
print(f"SSIM       : {ssim_v:.4f}")
print(f"AlexNet-2  : {alexnet2_v:.4f}")
print(f"AlexNet-5  : {alexnet5_v:.4f}")
print(f"Inception3 : {inception3_v:.4f}")
print(f"CLIP       : {clip_v:.4f}")
print(f"EfficientNet(corr): {eff_v:.4f}")
print(f"SwAV(corr) : {swav_v:.4f}")

# ---------- Save ----------
# JSON keys in output order, paired with the accumulator each one reports.
_result_keys = ("pixel_corr", "ssim", "alexnet2", "alexnet5",
                "inception3", "clip", "efficientnet", "swav")
_result_values = (pix_corr_v, ssim_v, alexnet2_v, alexnet5_v,
                  inception3_v, clip_v, eff_v, swav_v)
# float() turns NumPy scalars into plain floats that json can serialize.
results = {key: float(value) for key, value in zip(_result_keys, _result_values)}

save_path = f"./subj0{subj}_prior_vae{eye}/eval_results.json"
with open(save_path, "w") as fh:
    json.dump(results, fh, indent=4)
print(f"✅ 评估结果已保存: {save_path}")